{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to SafePathNet\n",
    "\n",
    "In this notebook you are going to train the multimodal prediction model presented in [Safe Real-World Autonomous Driving by Learning to Predict and Plan with a Mixture of Experts](https://arxiv.org/abs/2211.02131).\n",
    "Please note that we are releasing the prediction part of our approach only, but the code can be easily extended to planning too.\n",
    "\n",
    "You will train and test your model using the Woven by Toyota Prediction Dataset and [L5Kit](https://github.com/woven-planet/l5kit).\n",
    "**Before starting, please download the [Woven by Toyota Prediction Dataset 2020](https://woven.toyota/en/prediction-dataset) and follow [the instructions](https://github.com/woven-planet/l5kit#download-the-datasets) to correctly organise it.**\n",
    "\n",
    "### Model\n",
    "\n",
    "From the paper:\n",
    "```\n",
    "The architecture of SafePathNet is similar to those of VectorNet [11] and DETR [7], combining an element-wise point encoder [23] and a Transformer [31]. The element-wise point encoder consists of two PointNet-like modules that are used to compress each input element from a set of points to a single feature vector of the same size. A series of Transformer Encoder layers are used to model the relationships between all input elements (SDV, road agents, static and dynamic map, route), encoded by the point encoder. Then, a series of Transformer Decoders are used to query agents features. We make use of a set of learnable embeddings to construct the queries of the Transformer Decoders. M learnable query embeddings are used to obtain a variable number of M different queries for each road agent. An agent-specific MLP decoder converts each agent feature to a future trajectory. In addition to trajectories, the decoder predicts a logit for each agent trajectory. For each element, the corresponding logits are converted to a probability distribution over the future trajectories by applying a softmax function. All road agents are modeled independently, but predicted jointly in parallel.\n",
    "```\n",
    "This is a diagram of the full model:\n",
    "\n",
    "![model](../../docs/images/safepathnet/safepathnet_model.svg)\n",
    "\n",
    "\n",
    "#### Inputs\n",
    "Following previous works, SafePathNet is based on a vectorized representation of the world, centered on the ego location.\n",
    "Please refer to the paper for more details.\n",
    "\n",
    "\n",
    "#### Outputs\n",
    "SafePathNet outputs a trajectory distribution (in the form of a set of trajectories and a probability distribution over them) for each road agent (including ego).\n",
    "All road agents are modeled independently, but predicted jointly in parallel.\n",
    "Each timestep is a tuple consisting of `(X, Y, yaw)`.\n",
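    "\n",
    "As a rough illustration of this output format, the sketch below turns per-trajectory logits into a probability distribution and picks the most likely trajectory for each agent. The tensors are random stand-ins and the sizes are hypothetical; they are not taken from the model or the config.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical sizes, chosen only for illustration.\n",
    "num_agents, num_trajectories, num_timesteps = 4, 3, 12\n",
    "\n",
    "# Random stand-ins for the model outputs: (X, Y, yaw) per timestep and one logit per trajectory.\n",
    "trajectories = torch.randn(num_agents, num_trajectories, num_timesteps, 3)\n",
    "logits = torch.randn(num_agents, num_trajectories)\n",
    "\n",
    "# Softmax over the trajectory dimension gives one probability distribution per agent.\n",
    "probabilities = logits.softmax(dim=-1)\n",
    "\n",
    "# Select the most likely trajectory of each agent.\n",
    "best_mode = probabilities.argmax(dim=-1)\n",
    "best_trajectory = trajectories[torch.arange(num_agents), best_mode]\n",
    "\n",
    "print(probabilities.sum(dim=-1))  # each entry is 1 up to floating point error\n",
    "print(best_trajectory.shape)  # torch.Size([4, 12, 3])\n",
    "```\n",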
    "\n",
    "### Training\n",
    "\n",
    "Our model represents a mixture of experts, comprised of a set of experts and an expert selection function.\n",
    "We train them jointly while avoiding mode collapse using a winner-takes-all approach.\n",
    "\n",
    "From the paper:\n",
    "```\n",
    "Our model represents a MoE and predicts multiple trajectories for the SDV and each road agent, corresponding to N/M experts, and a probability distribution over each trajectory set, corresponding to expert selection. To train the experts and expert selection jointly while avoiding mode collapse, we use a winner-takes-all approach, somewhat similar to previous methods [10]. Similarly to DETR [7], we formulate a matching cost between predicted and target trajectories and probabilities, making the expert with minimal cost the winner.\n",
    "```\n",
    "\n",
    "We define our training objective as minimizing the distance between predicted and ground truth agents’ future trajectories (imitation loss) and the negative log likelihood of the selected trajectory (matching loss).\n",
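    "\n",
    "The winner-takes-all idea can be sketched for a single agent as follows. This is a simplified illustration and not the loss implemented in `SafePathNetModel`: the L1 imitation cost, the way the two terms are combined, and the value of `cost_prob_coeff` are assumptions made for the example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_trajectories, num_timesteps = 3, 12\n",
    "cost_prob_coeff = 0.1  # hypothetical weight of the probability term\n",
    "\n",
    "# Random stand-ins for the expert trajectories, their logits and the ground truth.\n",
    "predicted = torch.randn(num_trajectories, num_timesteps, 3, requires_grad=True)\n",
    "logits = torch.randn(num_trajectories, requires_grad=True)\n",
    "target = torch.randn(num_timesteps, 3)\n",
    "\n",
    "# Per-expert imitation cost: mean L1 distance to the ground truth trajectory.\n",
    "imitation_cost = (predicted - target).abs().mean(dim=(-1, -2))\n",
    "# Per-expert negative log likelihood under the expert selection distribution.\n",
    "nll = -F.log_softmax(logits, dim=-1)\n",
    "\n",
    "# Matching: the expert with minimal total cost is the winner (no gradient through the choice).\n",
    "with torch.no_grad():\n",
    "    winner = (imitation_cost + cost_prob_coeff * nll).argmin()\n",
    "\n",
    "# Only the winner receives the imitation gradient; the selection is pushed towards the winner.\n",
    "loss = imitation_cost[winner] + cost_prob_coeff * nll[winner]\n",
    "loss.backward()\n",
    "\n",
    "print(winner.item(), loss.item())\n",
    "print(predicted.grad.abs().sum(dim=(-1, -2)))  # non-zero only for the winning expert\n",
    "```\n",
    "\n",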
    "Please refer to the paper for more information."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import packages\n",
    "Import packages (requires a working installation of l5kit) and set random seeds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from tempfile import gettempdir\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.data import ChunkedDataset, LocalDataManager\n",
    "from l5kit.dataset import EgoAgentDatasetVectorized\n",
    "from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset\n",
    "from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS\n",
    "from l5kit.evaluation.metrics import average_displacement_error_oracle, final_displacement_error_oracle\n",
    "from l5kit.planning.vectorized.common import build_matrix, transform_points\n",
    "from l5kit.prediction.vectorized.safepathnet_model import SafePathNetModel\n",
    "from l5kit.vectorization.vectorizer_builder import build_vectorizer\n",
    "\n",
    "torch.manual_seed(123)\n",
    "np.random.seed(123)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare data path and load cfg\n",
    "\n",
    "By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.\n",
    "\n",
    "Then, we load our config file with relative paths and other configurations (vectorizer, training params...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download L5 Sample Dataset\n",
    "import os\n",
    "RunningInCOLAB = 'google.colab' in str(get_ipython())\n",
    "if RunningInCOLAB:\n",
    "    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q\n",
    "    !sh ./setup_notebook_colab.sh\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = open(\"./dataset_dir.txt\", \"r\").read().strip()\n",
    "else:\n",
    "    print(\"Not running in Google Colab.\")\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = \"PATH_TO_DATASET\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define local data manager\n",
    "dm = LocalDataManager(None)\n",
    "\n",
    "# load the experiment config\n",
    "cfg = load_config_data(\"./config.yaml\")\n",
    "print(\"Configuration loaded.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize the training dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# INIT DATASET\n",
    "train_zarr = ChunkedDataset(dm.require(cfg[\"train_data_loader\"][\"key\"])).open()\n",
    "\n",
    "vectorizer = build_vectorizer(cfg, dm)\n",
    "train_dataset = EgoAgentDatasetVectorized(cfg, train_zarr, vectorizer)\n",
    "print(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the model\n",
    "Let's define the SafePathNet model and move it to GPU, if available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SafePathNetModel(\n",
    "    history_num_frames_ego=cfg[\"model_params\"][\"history_num_frames_ego\"],\n",
    "    history_num_frames_agents=cfg[\"model_params\"][\"history_num_frames_agents\"],\n",
    "    num_timesteps=cfg[\"model_params\"][\"future_num_frames\"],\n",
    "    weights_scaling=cfg[\"model_params\"][\"weights_scaling\"],\n",
    "    criterion=nn.L1Loss(reduction=\"none\"),\n",
    "    disable_other_agents=cfg[\"model_params\"][\"disable_other_agents\"],\n",
    "    disable_map=cfg[\"model_params\"][\"disable_map\"],\n",
    "    disable_lane_boundaries=cfg[\"model_params\"][\"disable_lane_boundaries\"],\n",
    "    agent_num_trajectories=cfg[\"model_params\"][\"agent_num_trajectories\"],\n",
    "    max_num_agents=cfg[\"data_generation_params\"][\"other_agents_num\"],\n",
    "    cost_prob_coeff=cfg[\"model_params\"][\"cost_prob_coeff\"],\n",
    ")\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "print(\"Model created and loaded on device:\", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare for training\n",
    "Our `EgoAgentDatasetVectorized` inherits from PyTorch `Dataset`, so we can use it inside a `DataLoader` to enable multi-processing.\n",
    "It extends the dataset `EgoDatasetVectorized` to include ego as a road agent and to support agent prediction evaluation, while keeping the scene ego-centric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cfg = cfg[\"train_data_loader\"]\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=train_cfg[\"shuffle\"], batch_size=train_cfg[\"batch_size\"],\n",
    "                              num_workers=train_cfg[\"num_workers\"])\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "num_epochs = cfg[\"train_params\"][\"num_epochs\"]\n",
    "num_steps_per_epoch = len(train_dataloader)  # one step per batch\n",
    "max_num_steps = min(num_steps_per_epoch, cfg[\"train_params\"][\"max_num_steps\"])\n",
    "num_steps_per_log = max(1, max_num_steps // 100)\n",
    "checkpoint_every_n_epochs = cfg[\"train_params\"][\"checkpoint_every_n_epochs\"]\n",
    "num_warmup_epochs = cfg[\"train_params\"][\"num_epochs\"] // 5\n",
    "\n",
    "def lr_lambda_warmup_cosine(step: int) -> float:\n",
    "    steps_per_epoch = max_num_steps\n",
    "    total_steps = num_epochs * steps_per_epoch\n",
    "    warmup_steps = num_warmup_epochs * steps_per_epoch\n",
    "\n",
    "    if step < warmup_steps:  # warmup\n",
    "        return step / warmup_steps\n",
    "    else:\n",
    "        steps_since_warmup = step - warmup_steps\n",
    "        anneal_steps = total_steps - warmup_steps\n",
    "        completion = steps_since_warmup / anneal_steps\n",
    "        cosine_rate = float(0.5 * (1 + np.cos(completion * np.pi)))\n",
    "        return cosine_rate\n",
    "\n",
    "lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_warmup_cosine)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training loop\n",
    "Here, we purposely include a basic training loop. Clearly, many more components can be added to enrich logging and improve performance. Still, a reasonable performance can be obtained even with this simple loop.\n",
    "Please adapt the training length by changing \"train_params\" in the config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "model.train()\n",
    "torch.set_grad_enabled(True)\n",
    "\n",
    "loss_log = defaultdict(list)\n",
    "lr_log = list()\n",
    "progress_bar = tqdm(total=max_num_steps)\n",
    "\n",
    "print(f\"Starting training - {num_epochs} epochs\")\n",
    "print(f\"An epoch is composed of {min(len(train_dataloader), max_num_steps)} steps\")\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    print(f\"Epoch {epoch + 1} - Starting\")\n",
    "    progress_bar.reset()\n",
    "    for idx, data in enumerate(train_dataloader):\n",
    "        if idx == max_num_steps:\n",
    "            break\n",
    "\n",
    "        # Forward pass\n",
    "        data = {k: v.to(device) for k, v in data.items()}\n",
    "        result = model(data)\n",
    "        loss = result[\"loss\"]\n",
    "\n",
    "        # Backward pass\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "\n",
    "        # logging\n",
    "        if idx % num_steps_per_log == 0:\n",
    "            for key, res in result.items():\n",
    "                loss_log[key].append(res.item())\n",
    "            lr_log.append(lr_scheduler.get_last_lr())\n",
    "        \n",
    "        progress_bar.update()\n",
    "        progress_bar.set_description(f\"loss: {loss.item():.5f} - loss(avg): {np.mean(loss_log['loss'][-idx:]):.5f}\")\n",
    "    \n",
    "    if epoch % checkpoint_every_n_epochs == 0 or epoch + 1 == num_epochs:\n",
    "        save_path = f\"{gettempdir()}/safepathnet_model.{epoch}.pth\"\n",
    "        torch.save(model.state_dict(), save_path)\n",
    "        print(f\"Model saved at {save_path}.\")\n",
    "    \n",
    "    for key, loss in loss_log.items():\n",
    "        loss_last_epoch = loss[-idx // num_steps_per_log:]\n",
    "        plt.plot(np.arange(len(loss_last_epoch)), loss_last_epoch, label=key)\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "    \n",
    "    lr_log_last_epoch = lr_log[-idx // num_steps_per_log:]\n",
    "    plt.plot(np.arange(len(lr_log_last_epoch)), lr_log_last_epoch, label='learning rate')\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "    \n",
    "    print(f\"Epoch {epoch + 1} - Ended\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the final train loss curve\n",
    "We can plot the train loss against the iterations (max 100 values per epoch) to check if our model has converged.\n",
    "We also plot the learning rate values used across the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "for key, loss in loss_log.items():\n",
    "    plt.plot(np.arange(len(loss)), loss, label=key)\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "plt.plot(np.arange(len(lr_log)), lr_log, label='learning rate')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Store the model in TorchScript format\n",
    "\n",
    "Let's also store the model in TorchScript format. This format allows us to reload the model and weights without requiring the class definition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model.eval()\n",
    "jit_model = torch.jit.script(model.cpu())\n",
    "path_to_save = f\"{gettempdir()}/safepathnet_script.pth\"\n",
    "jit_model.save(path_to_save)\n",
    "print(f\"MODEL STORED at {path_to_save}\")"
   ]
  },
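  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate why the TorchScript format is convenient, here is a minimal round-trip sketch. It uses a hypothetical `TinyModel` and file name as stand-ins for the trained model, so it can be run on its own; `torch.jit.load` restores the scripted module without needing the Python class.\n",
    "\n",
    "```python\n",
    "import os\n",
    "from tempfile import gettempdir\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class TinyModel(nn.Module):\n",
    "    # Hypothetical stand-in for the trained model.\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(4, 2)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return self.linear(x)\n",
    "\n",
    "\n",
    "tiny_model = TinyModel().eval()\n",
    "tiny_path = os.path.join(gettempdir(), 'tiny_script.pth')\n",
    "torch.jit.script(tiny_model).save(tiny_path)\n",
    "\n",
    "# Reload the scripted module; the TinyModel class is not needed at this point.\n",
    "reloaded = torch.jit.load(tiny_path, map_location='cpu').eval()\n",
    "\n",
    "x = torch.randn(3, 4)\n",
    "with torch.no_grad():\n",
    "    assert torch.allclose(tiny_model(x), reloaded(x))\n",
    "print('Reloaded module matches the original.')\n",
    "```"
   ]
  },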
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation\n",
    "\n",
    "Following the challenge evaluation protocol, **the test set for the competition is \"chopped\" using the `chop_dataset` function**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GENERATE AND LOAD CHOPPED DATASET\n",
    "num_frames_to_chop = 100\n",
    "eval_cfg = cfg[\"val_data_loader\"]\n",
    "eval_base_path = os.path.join(os.environ[\"L5KIT_DATA_FOLDER\"], f\"{eval_cfg['key'].split('.')[0]}_chopped_100\")\n",
    "if not os.path.exists(eval_base_path):\n",
    "    eval_base_path = create_chopped_dataset(\n",
    "        dm.require(eval_cfg[\"key\"]), \n",
    "        cfg[\"raster_params\"][\"filter_agents_threshold\"], \n",
    "        num_frames_to_chop, \n",
    "        cfg[\"model_params\"][\"future_num_frames\"], \n",
    "        MIN_FUTURE_STEPS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result is that **each scene has been reduced to only 100 frames**, and **only valid agents in the 100th frame will be used to compute the metrics**. Because the following frames in the scene have been chopped off, we can't just look ahead to get the future of those agents.\n",
    "\n",
    "In this example, we simulate this pipeline by running `chop_dataset` on the validation set. The function stores:\n",
    "- a new chopped `.zarr` dataset, in which each scene has only the first 100 frames;\n",
    "- a numpy mask array where only valid agents in the 100th frame are True;\n",
    "- a ground-truth file with the future coordinates of those agents.\n",
    "\n",
    "Please note how the total number of frames is now equal to the number of scenes multiplied by `num_frames_to_chop`.\n",
    "\n",
    "The remaining frames in each scene have been successfully chopped off from the data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** SafePathNet is able to predict future trajectories of all the agents in a scene in a single pass of the model, using the ego-centric reference frame.\n",
    "Thus, we use a modified version of the EgoDataset that additionally returns the ids of the agents that are used for evaluation in the Prediction Challenge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_zarr_path = str(Path(eval_base_path) / Path(dm.require(eval_cfg[\"key\"])).name)\n",
    "eval_mask_path = str(Path(eval_base_path) / \"mask.npz\")\n",
    "eval_gt_path = str(Path(eval_base_path) / \"gt.csv\")\n",
    "\n",
    "eval_zarr = ChunkedDataset(eval_zarr_path).open()\n",
    "eval_mask = np.load(eval_mask_path)[\"arr_0\"]\n",
    "\n",
    "vectorizer = build_vectorizer(cfg, dm)\n",
    "\n",
    "# INIT DATASET AND LOAD MASK\n",
    "eval_dataset = EgoAgentDatasetVectorized(cfg, eval_zarr, vectorizer, agents_mask=eval_mask, eval_mode=True)\n",
    "print(eval_dataset)\n",
    "\n",
    "eval_cfg = cfg[\"val_data_loader\"]\n",
    "eval_dataloader = DataLoader(eval_dataset, shuffle=eval_cfg[\"shuffle\"], batch_size=eval_cfg[\"batch_size\"],\n",
    "                             num_workers=eval_cfg[\"num_workers\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Storing Predictions\n",
    "There is a small catch to be aware of when saving the model predictions. The outputs of the model are coordinates in `ego` space, and we need to convert them into displacements in `world` space.\n",
    "\n",
    "To do so, we first convert them back into `world` space and then subtract from each agent's coordinates its own `world` centroid."
   ]
  },
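  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two steps above can be sketched with plain NumPy for a single agent. This is a simplified illustration with hypothetical poses and helper functions; the evaluation loop below uses the l5kit helpers `transform_points` and `build_matrix` on batched tensors instead.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def world_from_ego_matrix(ego_xy, ego_yaw):\n",
    "    # 3x3 homogeneous transform mapping ego-frame points to world-frame points.\n",
    "    c, s = np.cos(ego_yaw), np.sin(ego_yaw)\n",
    "    return np.array([[c, -s, ego_xy[0]],\n",
    "                     [s, c, ego_xy[1]],\n",
    "                     [0.0, 0.0, 1.0]])\n",
    "\n",
    "\n",
    "def transform(points, matrix):\n",
    "    # Apply a 3x3 homogeneous transform to an [N, 2] array of points.\n",
    "    return points @ matrix[:2, :2].T + matrix[:2, 2]\n",
    "\n",
    "\n",
    "# Hypothetical ego pose and a predicted agent trajectory expressed in the ego frame.\n",
    "world_from_ego = world_from_ego_matrix(ego_xy=np.array([10.0, 5.0]), ego_yaw=np.pi / 2)\n",
    "agent_centroid_ego = np.array([[2.0, 0.0]])\n",
    "agent_future_ego = np.array([[3.0, 0.0], [4.0, 0.0]])\n",
    "\n",
    "# Step 1: ego frame -> world frame.\n",
    "agent_centroid_world = transform(agent_centroid_ego, world_from_ego)\n",
    "agent_future_world = transform(agent_future_ego, world_from_ego)\n",
    "\n",
    "# Step 2: world coordinates -> displacements from the agent's own world centroid.\n",
    "displacements = agent_future_world - agent_centroid_world\n",
    "print(np.round(displacements, 6))  # approximately [[0, 1], [0, 2]]\n",
    "```"
   ]
  },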
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EVAL LOOP\n",
    "model.eval()\n",
    "model.to(device)\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# store information for evaluation\n",
    "future_coords_offsets_pd = []\n",
    "future_traj_confidence = []\n",
    "timestamps = []\n",
    "agent_ids = []\n",
    "agent_of_interest_ids = []\n",
    "missing_agent_of_interest_ids = []\n",
    "missing_agent_of_interest_timestamp = []\n",
    "\n",
    "# torch.isin is available only from PyTorch 1.10 - defining a simple alternative\n",
    "def torch_isin(ar1, ar2):\n",
    "    return (ar1[..., None] == ar2).any(-1)\n",
    "\n",
    "# iterate over validation dataset\n",
    "progress_bar = tqdm(eval_dataloader)\n",
    "for data in progress_bar:\n",
    "    data = {k: v.to(device) for k, v in data.items()}\n",
    "    outputs = model(data)\n",
    "\n",
    "    # [batch_size, max_num_agents, num_trajectories, num_timesteps, 2]\n",
    "    agent_xy = outputs[\"all_agent_positions\"]\n",
    "    # [batch_size, max_num_agents, num_trajectories, num_timesteps, 1]\n",
    "    agent_yaw = outputs[\"all_agent_yaws\"]\n",
    "    # [batch_size, max_num_agents, num_trajectories]\n",
    "    agent_logits = outputs[\"agent_traj_logits\"]\n",
    "    \n",
    "    # [batch_size, max_num_agents, num_trajectories, num_timesteps, 3]\n",
    "    agent_pos = torch.cat((agent_xy, agent_yaw), dim=-1)\n",
    "\n",
    "    # ego-centric agent coords must be converted to world frame first\n",
    "    # [batch_size, 3, 3]\n",
    "    world_from_agents = data[\"world_from_agent\"].float()\n",
    "    # [batch_size]\n",
    "    world_from_agents_yaw = data[\"yaw\"].float()\n",
    "    # shape of data[\"all_other_agents_history_positions\"]: [batch_size, num_agents, num_history_frames, 2]\n",
    "    # [batch_size, num_agents, 1, 3]\n",
    "    agent_t0_pos_yaw = torch.cat((data[\"all_other_agents_history_positions\"][:, :, :1],\n",
    "                                  data[\"all_other_agents_history_yaws\"][:, :, :1]), dim=-1)\n",
    "    agent_t0_avail = data[\"all_other_agents_history_availability\"][:, :, :1]\n",
    "    # [batch_size, num_agents, 1, 3]\n",
    "    world_agent_t0_pos_yaw = transform_points(agent_t0_pos_yaw, world_from_agents, avail=agent_t0_avail,\n",
    "                                              yaw=world_from_agents_yaw)\n",
    "    world_agent_pos = transform_points(agent_pos.flatten(2,3), world_from_agents, avail=agent_t0_avail).view_as(agent_pos)\n",
    "\n",
    "    # then can be converted to agent-relative\n",
    "    world_agents_t0_pos_exp = world_agent_t0_pos_yaw[..., :2]\n",
    "    world_agents_t0_yaw_exp = world_agent_t0_pos_yaw[..., 2]\n",
    "    # [batch_size * max_num_agents, 3, 3]\n",
    "    _, matrix = build_matrix(world_agents_t0_pos_exp.reshape(-1, 2), world_agents_t0_yaw_exp.reshape(-1))\n",
    "    # [batch_size, max_num_agents, 3, 3]\n",
    "    matrix = matrix.view(list(world_agent_t0_pos_yaw.shape[:2]) + [3, 3])\n",
    "    # [batch_size * max_num_agents * num_trajectories * num_timesteps, 3, 3]\n",
    "    matrix = matrix.unsqueeze(2).unsqueeze(2).expand(list(agent_pos.shape[:-1]) + [3, 3]).reshape(-1, 3, 3)\n",
    "    coords_offset = transform_points(world_agent_pos.reshape(-1, 1, 1, 3), matrix, \n",
    "                                     avail=torch.ones_like(world_agent_pos.reshape(-1, 1, 1, 3)[..., 0]))\n",
    "    coords_offset = coords_offset.view_as(world_agent_pos)\n",
    "    \n",
    "    # need to filter per agents of interest (from original prediction evaluation)\n",
    "    agents_track_ids = data[\"all_other_agents_track_ids\"]\n",
    "    agents_of_interest = data[\"all_valid_agents_track_ids\"]\n",
    "    agents_track_ids_mask = torch.zeros_like(agents_track_ids, dtype=torch.bool)\n",
    "    missing_agents_mask = torch.zeros_like(agents_of_interest, dtype=torch.bool)\n",
    "    for batch_idx in range(agents_track_ids.shape[0]):\n",
    "        agents_track_ids_mask[batch_idx] = torch_isin(agents_track_ids[batch_idx], agents_of_interest[batch_idx]) & \\\n",
    "                                           (agents_track_ids[batch_idx] != 0)\n",
    "        missing_agents_mask[batch_idx] = ~torch_isin(agents_of_interest[batch_idx], agents_track_ids[batch_idx]) & \\\n",
    "                                         (agents_of_interest[batch_idx] != 0)\n",
    "    # we may miss some agents due to the limit cfg[\"data_generation_params\"][\"other_agents_num\"], we will consider them stationary\n",
    "    missing_agents_ids = agents_of_interest[missing_agents_mask]\n",
    "    if torch.any(missing_agents_mask):\n",
    "        # print(len(missing_agents_ids), missing_agents_ids[missing_agents_ids != 0])\n",
    "        missing_agents_ids = missing_agents_ids[missing_agents_ids != 0]\n",
    "        missing_agent_of_interest_ids.append(missing_agents_ids.cpu())\n",
    "        missing_timestamps = []\n",
    "        for batch_idx, num_missing_agents in enumerate(missing_agents_mask.sum(-1)):\n",
    "            missing_timestamps.extend([data[\"timestamp\"][batch_idx]] * num_missing_agents)\n",
    "        missing_agent_of_interest_timestamp.append(torch.tensor(missing_timestamps))\n",
    "    \n",
    "    # move the valid data to CPU\n",
    "    relevant_coords_offset = coords_offset[agents_track_ids_mask].cpu()\n",
    "    traj_confidence = agent_logits[agents_track_ids_mask].cpu()\n",
    "    relevant_agent_track_ids = agents_track_ids[agents_track_ids_mask].cpu()\n",
    "    relevant_timestamps = data[\"timestamp\"].unsqueeze(1).expand(agents_track_ids.shape)[agents_track_ids_mask].cpu()\n",
    "\n",
    "    # add them to the result lists\n",
    "    future_coords_offsets_pd.append(relevant_coords_offset)\n",
    "    future_traj_confidence.append(traj_confidence)\n",
    "    timestamps.append(relevant_timestamps)\n",
    "    agent_ids.append(relevant_agent_track_ids)\n",
    "\n",
    "# add the missing agents as stationary\n",
    "missing_agent_of_interest_ids = torch.cat(missing_agent_of_interest_ids, dim=0)\n",
    "missing_agent_of_interest_timestamp = torch.cat(missing_agent_of_interest_timestamp, dim=0)\n",
    "stationary_trajectories = torch.zeros(list(missing_agent_of_interest_ids.shape[:1]) + list(future_coords_offsets_pd[0].shape[1:]))\n",
    "uniform_probabilities = torch.ones(list(missing_agent_of_interest_ids.shape[:1]) + list(future_traj_confidence[0].shape[1:]))\n",
    "agent_ids.append(missing_agent_of_interest_ids)\n",
    "future_coords_offsets_pd.append(stationary_trajectories)\n",
    "future_traj_confidence.append(uniform_probabilities)\n",
    "timestamps.append(missing_agent_of_interest_timestamp)\n",
    "\n",
    "# concatenate all the results in a single np array    \n",
    "future_coords_offsets_pd = torch.cat(future_coords_offsets_pd, dim=0).numpy()\n",
    "future_traj_confidence = torch.cat(future_traj_confidence, dim=0).softmax(-1).numpy()\n",
    "timestamps = torch.cat(timestamps, dim=0).numpy().astype(np.int64)\n",
    "agent_ids = torch.cat(agent_ids, dim=0).numpy().astype(np.int64)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's verify that the number of predicted trajectories matches the number of agents in the original evaluation\n",
    "assert len(future_coords_offsets_pd) == 94694"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Overall, we missed {len(missing_agent_of_interest_ids)} agents over a total of {94694} agents \"\n",
    "      f\"(~{len(missing_agent_of_interest_ids) / 94694:.5%})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save results in csv format\n",
    "After the model has predicted trajectories for our evaluation set, we can save them in a `csv` file compatible with the l5kit evaluation tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_path = f\"{gettempdir()}/pred.csv\"\n",
    "\n",
    "write_pred_csv(pred_path,\n",
    "               timestamps=timestamps,\n",
    "               track_ids=agent_ids,\n",
    "               coords=future_coords_offsets_pd[..., :2],\n",
    "               confs=future_traj_confidence,\n",
    "               max_modes=cfg[\"model_params\"][\"agent_num_trajectories\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform Evaluation\n",
    "We can evaluate the model predictions with the existing metrics from l5kit, which support multimodal predictions. In our case, we're interested in the minimum Average Displacement Error (minADE) and the minimum Final Displacement Error (minFDE). Other metrics can be added from `l5kit.evaluation.metrics`."
   ]
  },
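  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the two metrics can be sketched for a single agent with NumPy. This is a simplified version written for illustration (the helper name `min_ade_fde` is hypothetical); it ignores availability masks and confidences, so prefer the l5kit functions for actual evaluation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def min_ade_fde(gt, pred):\n",
    "    # gt: [num_timesteps, 2]; pred: [num_modes, num_timesteps, 2]\n",
    "    distances = np.linalg.norm(pred - gt[None], axis=-1)  # [num_modes, num_timesteps]\n",
    "    min_ade = distances.mean(axis=-1).min()  # best mode by average displacement\n",
    "    min_fde = distances[:, -1].min()  # best mode by final displacement\n",
    "    return float(min_ade), float(min_fde)\n",
    "\n",
    "\n",
    "gt = np.array([[1.0, 0.0], [2.0, 0.0]])\n",
    "pred = np.array([\n",
    "    [[1.0, 1.0], [2.0, 1.0]],  # mode 0: constant 1 m lateral offset\n",
    "    [[1.0, 0.0], [2.0, 0.5]],  # mode 1: exact at the first step, 0.5 m off at the last\n",
    "])\n",
    "print(min_ade_fde(gt, pred))  # (0.25, 0.5)\n",
    "```"
   ]
  },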
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you restart the notebook and want to evaluate an existing csv file, uncomment the line below and set your csv path\n",
    "# pred_path = 'PATH_TO_CSV'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# COMPUTE AND PRINT METRICS\n",
    "metrics = compute_metrics_csv(eval_gt_path, pred_path, \n",
    "                              [average_displacement_error_oracle, final_displacement_error_oracle],\n",
    "                              max_modes=cfg[\"model_params\"][\"agent_num_trajectories\"])\n",
    "for metric_name, metric_mean in metrics.items():\n",
    "    print(metric_name, metric_mean)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualise Results\n",
    "Coming soon.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations on training and evaluating your SafePathNet model!\n",
    "\n",
    "For more information on SafePathNet, please have a look at our paper  \n",
    "[Safe Real-World Autonomous Driving by Learning to Predict and Plan with a Mixture of Experts](https://arxiv.org/abs/2211.02131)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
